{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hRWOI1nxutyx"
   },
   "source": [
    "# Overview\n",
    "\n",
    "This CodeLab demonstrates how to build an LSTM model for MNIST digit recognition using Keras and how to convert it to TensorFlow Lite.\n",
    "\n",
    "This CodeLab closely follows the `tf.lite.experimental.nn.TFLiteLSTMCell`\n",
    "[CodeLab](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/TensorFlowLite_LSTM_Keras_Tutorial.ipynb). However, because the experimental new converter supports control flow, we can define the model with control flow directly, without refactoring the code.\n",
    "\n",
    "Note that we are not trying to build a real-world application here, only to demonstrate how to use TensorFlow Lite. You can build a much better MNIST model with a CNN. For a more canonical LSTM CodeLab, see [this example](https://github.com/keras-team/keras/blob/master/examples/imdb_lstm.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wZCbNdY7MNSP"
   },
   "source": [
    "## Step 0: Prerequisites\n",
    "\n",
    "It's recommended to try this feature with the newest TensorFlow nightly pip build."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6Zk2sUHUm5td"
   },
   "outputs": [],
   "source": [
    "!pip install tf-nightly --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "R3Ku1Lx9vvfX"
   },
   "source": [
    "## Step 1: Build the MNIST LSTM model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "yQpmCIqJPetJ"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wiYZoDlC5SEJ"
   },
   "outputs": [],
   "source": [
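    "# Each 28x28 image is read as a sequence of 28 rows (time steps) with 28 pixels (features) each.\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Input(shape=(28, 28), name='input'),\n",
    "    # The LSTM returns only its final output, a vector of 20 units per image.\n",
    "    tf.keras.layers.LSTM(20),\n",
    "    tf.keras.layers.Flatten(),\n",
    "    # One softmax probability per digit class (0-9).\n",
    "    tf.keras.layers.Dense(10, activation=tf.nn.softmax, name='output')\n",
    "])\n",
    "# Labels are integer class ids, hence the sparse categorical cross-entropy loss.\n",
    "model.compile(optimizer='adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "model.summary()"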
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ff6X9gg_wk7K"
   },
   "source": [
    "## Step 2: Train and evaluate the model.\n",
    "\n",
    "We train the model on the MNIST training set and evaluate it on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "23W41fiRPOmh"
   },
   "outputs": [],
   "source": [
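    "# Load the MNIST dataset.\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "# Scale pixel values to [0, 1]. The division yields float64, so cast to\n",
    "# float32, the input type of the model and of the converted TFLite model.\n",
    "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "x_train = x_train.astype(np.float32)\n",
    "x_test = x_test.astype(np.float32)\n",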
    "\n",
    "# Change this to True if you want to test the flow rapidly.\n",
    "# Train with a small dataset and only 1 epoch. The model will work poorly\n",
    "# but this provides a fast way to test if the conversion works end to end.\n",
    "_FAST_TRAINING = False\n",
    "_EPOCHS = 5\n",
    "if _FAST_TRAINING:\n",
    "  _EPOCHS = 1\n",
    "  _TRAINING_DATA_COUNT = 1000\n",
    "  x_train = x_train[:_TRAINING_DATA_COUNT]\n",
    "  y_train = y_train[:_TRAINING_DATA_COUNT]\n",
    "\n",
    "model.fit(x_train, y_train, epochs=_EPOCHS)\n",
    "# Report the loss and accuracy on the held-out test set.\n",
    "model.evaluate(x_test, y_test, verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NtPJGiIQw0nM"
   },
   "source": [
    "## Step 3: Convert the Keras model to a TensorFlow Lite model.\n",
    "\n",
    "No special handling is needed here: we convert to a TensorFlow Lite model as usual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Tbuu_8PFz-x_"
   },
   "outputs": [],
   "source": [
    "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
    "tflite_model = converter.convert()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5rHrZkIuxxar"
   },
   "source": [
    "## Step 4: Check the converted TensorFlow Lite model.\n",
    "\n",
    "Now load the TensorFlow Lite model and use the TensorFlow Lite Python interpreter to verify that its output matches the Keras model's output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8lao097MnFf2"
   },
   "outputs": [],
   "source": [
    "# Run the model with TensorFlow to get expected results.\n",
    "expected = model.predict(x_test[0:1])\n",
    "\n",
    "# Run the same input through the TensorFlow Lite interpreter.\n",
    "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
    "interpreter.allocate_tensors()\n",
    "input_details = interpreter.get_input_details()\n",
    "output_details = interpreter.get_output_details()\n",
    "interpreter.set_tensor(input_details[0][\"index\"], x_test[0:1, :, :])\n",
    "interpreter.invoke()\n",
    "result = interpreter.get_tensor(output_details[0][\"index\"])\n",
    "\n",
    "# Assert that the TFLite result is consistent with the TF result. The two\n",
    "# runtimes use different kernels, so compare to 5 decimal places.\n",
    "np.testing.assert_almost_equal(expected, result, decimal=5)\n",
    "print(\"Done. The result of TensorFlow matches the result of TensorFlow Lite.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
